{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "ec823036-79d6-4029-aeb8-636c6d663571",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "source": [
    "# Basic demonstration of LlamaIndex integration functionality with the Unity Catalog AI Toolkit SDK\n",
    "\n",
    "To get started with this, you will need an OpenAI API Key. \n",
    "\n",
    "Once you have acquired your key, set it to the environment variable `OPENAI_API_KEY` after storing it in the `Databricks Secrets` API (accessible via dbutils or the databricks sdk workspace client)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "c7a49c85-025f-4c28-b6ad-b4a0b302769d",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "%pip install -Uqqq unitycatalog-llamaindex[databricks] openai llama_index mlflow\n",
    "\n",
    "%restart_python"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "aa62833c-83ea-47a6-a50b-136ec4ab151e",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "source": [
    "## Setting your API Key\n",
    "\n",
    "Don't forget to remove the key after you're done running cell 4!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "90d970f0-da57-44b0-816c-90fd9d0de55a",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "import base64\n",
    "import os\n",
    "\n",
    "from databricks.sdk import WorkspaceClient\n",
    "\n",
    "workspace_client = WorkspaceClient()\n",
    "\n",
    "secret_scope = \"ben_wilson\"  # Change me!\n",
    "\n",
    "# Run this if you don't have the API key set to your secrets scope yet\n",
    "\n",
    "# if secret_scope not in [scope.name for scope in workspace_client.secrets.list_scopes()]:\n",
    "#     workspace_client.secrets.create_scope(secret_scope)\n",
    "\n",
    "# my_secret = \"<your API key, temporarily>\"\n",
    "\n",
    "# workspace_client.secrets.put_secret(scope=secret_scope, key=\"openai_api_key\", string_value=my_secret)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "fd54ca3c-b453-4ac6-968f-c47601ca91c1",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "source": [
    "## Fetch the key and set it to the environment variable key that the OpenAI SDK needs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "2f2ef55b-d3ba-4ddb-af72-ef110d8b9a2e",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "os.environ[\"OPENAI_API_KEY\"] = base64.b64decode(\n",
    "    workspace_client.secrets.get_secret(scope=secret_scope, key=\"openai_api_key\").value\n",
    ").decode()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "43b9d96b-1fc1-483c-93c2-ae9c33461790",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "assert \"OPENAI_API_KEY\" in os.environ, (\n",
    "    \"Please set the OPENAI_API_KEY environment variable to your OpenAI API key\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "433bb01a-7d2b-44c5-ac32-0f0e8b71386a",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "source": [
    "## Import the UC client for Databricks UC. \n",
    "This will allow for function creation through either the `create_function` API (requires the defined `sql_body` statement) or the `create_python_function` (requires a type-hint-applied and docstring commented python callable). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "17f134d4-0fe2-4479-8717-3ce183165623",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "from unitycatalog.ai.core.databricks import DatabricksFunctionClient\n",
    "from unitycatalog.ai.llama_index.toolkit import UCFunctionToolkit"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "6b8a211e-b576-4193-a085-6639b39c5f6f",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "source": [
    "## Set the UC Catalog and Schema \n",
    "\n",
    "You must set both of these that you will be using to store and execute your function(s). If these do not exist, ensure that you create them first. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "3c8d77d3-321a-49b7-a13b-9f4ee73ba726",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "CATALOG = \"ben_wilson\"  # Change me!\n",
    "SCHEMA = \"uc_func\"  # Change me if you want"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "15c65ce9-0464-4a4d-921c-21c149d244f2",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "client = DatabricksFunctionClient()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "478f153f-7106-4b54-940a-0f61fe6e5d79",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "source": [
    "### Define a Callable\n",
    "The requirements for the callable:\n",
    "\n",
    "**typing**\n",
    "\n",
    "Types **must** be supplied for both the arguments and the return type. Function signatures that do not have these defined will raise a `ValueError`.\n",
    "\n",
    "The following types are not allowed:\n",
    "`Union`\n",
    "`Any`\n",
    "\n",
    "Additional caveats:\n",
    "Collections **must** supply typing of the interior components. For instance, ``typing.Dict`` is not allowed, but ``typing.Dict[str, str]`` will work correctly. \n",
    "\n",
    "**doc strings**\n",
    "\n",
    "The doc string **must** be in the Google Docstring format.\n",
    "Args and Returns comments are optional, but the function description **is required**. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "c85dde3b-626a-4f8b-9ff1-ac6e850469db",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "def fetch_cold_weather(location: str) -> str:\n",
    "    \"\"\"\n",
    "    Fetches the current weather in celsius for a given location.\n",
    "\n",
    "    Args:\n",
    "        location (str): The location to fetch the weather for.\n",
    "\n",
    "    Returns:\n",
    "        str: The current weather in celsius for the given location.\n",
    "    \"\"\"\n",
    "\n",
    "    return \"1.9 C\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "f1d35198-5d37-4dda-9602-cab0648c3afd",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FunctionInfo(catalog_name='ben_wilson', comment='Fetches the current weather in celsius for a given location.', created_at=1729720429548, created_by='benjamin.wilson@databricks.com', data_type=<ColumnTypeName.STRING: 'STRING'>, external_language='Python', external_name=None, full_data_type='STRING', full_name='ben_wilson.uc_func.fetch_cold_weather', function_id='40ee2d86-ac4a-4b7e-b3ec-03f03f80ccb5', input_params=FunctionParameterInfos(parameters=[FunctionParameterInfo(name='location', type_text='string', type_name=<ColumnTypeName.STRING: 'STRING'>, position=0, comment='The location to fetch the weather for.', parameter_default=None, parameter_mode=None, parameter_type=<FunctionParameterType.PARAM: 'PARAM'>, type_interval_type=None, type_json='{\"name\":\"location\",\"type\":\"string\",\"nullable\":true,\"metadata\":{\"comment\":\"The location to fetch the weather for.\"}}', type_precision=0, type_scale=0)]), is_deterministic=False, is_null_call=None, metastore_id='19a85dee-54bc-43a2-87ab-023d0ec16013', name='fetch_cold_weather', owner='benjamin.wilson@databricks.com', parameter_style=<FunctionInfoParameterStyle.S: 'S'>, properties='{\"sqlConfig.spark.sql.ansi.enabled\":\"true\",\"sqlConfig.spark.sql.streaming.statefulOperator.stateRebalancing.enabled\":\"false\",\"sqlConfig.spark.sql.legacy.createHiveTableByDefault\":\"false\",\"sqlConfig.spark.sql.shuffleDependency.skipMigration.enabled\":\"true\",\"sqlConfig.spark.sql.streaming.stopTimeout\":\"15s\",\"sqlConfig.spark.sql.readSideCharPadding\":\"true\",\"sqlConfig.spark.sql.variable.substitute\":\"false\",\"sqlConfig.spark.databricks.sql.functions.aiForecast.enabled\":\"true\",\"sqlConfig.spark.sql.sources.default\":\"delta\",\"sqlConfig.spark.sql.hive.convertCTAS\":\"true\",\"sqlConfig.spark.sql.functions.remoteHttpClient.retryOnSocketTimeoutException\":\"true\",\"sqlConfig.spark.sql.sources.commitProtocolClass\":\"com.databricks.sql.transaction.directory.DirectoryAtomicCommitProtocol\",\"sqlConfig.spark.sql.functions.remoteHttpClient.retryOn400TimeoutError\":\"true\",\"sqlConfig.spark.sql.stableDerivedColumnAlias.enabled\":\"true\",\"sqlConfig.spark.sql.parquet.compression.codec\":\"snappy\",\"sqlConfig.spark.sql.streaming.stateStore.providerClass\":\"com.databricks.sql.streaming.state.RocksDBStateStoreProvider\"}', return_params=None, routine_body=<FunctionInfoRoutineBody.EXTERNAL: 'EXTERNAL'>, routine_definition='\\n    return \"1.9 C\"\\n', routine_dependencies=None, schema_name='uc_func', security_type=<FunctionInfoSecurityType.DEFINER: 'DEFINER'>, specific_name='fetch_cold_weather', sql_data_access=<FunctionInfoSqlDataAccess.NO_SQL: 'NO_SQL'>, sql_path=None, updated_at=1729720429548, updated_by='benjamin.wilson@databricks.com')"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "client.create_python_function(func=fetch_cold_weather, catalog=CATALOG, schema=SCHEMA, replace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "ef1f314c-aebd-4014-ac19-f5841fe00902",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# Create a tool instance to use with LlamaIndex\n",
    "\n",
    "toolkit = UCFunctionToolkit(\n",
    "    function_names=[f\"{CATALOG}.{SCHEMA}.fetch_cold_weather\"], client=client\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "182de212-e14c-46ad-8127-2d1c0558ba14",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "source": [
    "## Enable tracing in MLflow\n",
    "Auto-enabling tracing allows us to see the calls made by LlamaIndex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "253d42cc-bffe-42f6-9e0b-e9a033492f97",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "import mlflow\n",
    "\n",
    "mlflow.llama_index.autolog()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "fd08ff30-3b24-4abb-adb2-8841260507e7",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "source": [
    "## Submit the question\n",
    "\n",
    "In the question request, submit the defined tools from the `UCFunctionToolkit` instance functions that have been fetched. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "071e0891-23ae-4afc-bed8-e0d7ca43b18b",
     "showTitle": false,
     "tableResultSettingsMap": {},
     "title": ""
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> Running step 727d788f-557a-4d5e-9c76-ac036145121b. Step input: Please call a python execution tool to fetch the weather in Bogota Columbia.\n",
      "\u001b[1;3;38;5;200mThought: The current language of the user is: English. I need to use a tool to help me answer the question.\n",
      "Action: ben_wilson__uc_func__fetch_cold_weather\n",
      "Action Input: {'properties': AttributedDict([('location', 'Bogota, Colombia')])}\n",
      "\u001b[0m\u001b[1;3;34mObservation: {\"format\": \"SCALAR\", \"value\": \"1.9 C\"}\n",
      "\u001b[0m> Running step d44063ae-b5c3-4360-ac26-0926e2c87047. Step input: None\n",
      "\u001b[1;3;38;5;200mThought: I can answer without using any more tools. I'll use the user's language to answer\n",
      "Answer: The current weather in Bogota, Colombia is 1.9 degrees Celsius.\n",
      "\u001b[0m"
     ]
    },
    {
     "data": {
      "text/plain": [
       "AgentChatResponse(response='The current weather in Bogota, Colombia is 1.9 degrees Celsius.', sources=[ToolOutput(content='{\"format\": \"SCALAR\", \"value\": \"1.9 C\"}', tool_name='ben_wilson__uc_func__fetch_cold_weather', raw_input={'args': (), 'kwargs': {'properties': AttributedDict([('location', 'Bogota, Colombia')])}}, raw_output='{\"format\": \"SCALAR\", \"value\": \"1.9 C\"}', is_error=False)], source_nodes=[], is_dummy_stream=False, metadata=None)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "application/databricks.mlflow.trace": "\"tr-1281691e49bb4cf0b9b428efc3f359c9\"",
      "text/plain": [
       "Trace(request_id=tr-1281691e49bb4cf0b9b428efc3f359c9)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from llama_index.core.agent import ReActAgent\n",
    "from llama_index.llms.openai import OpenAI\n",
    "\n",
    "llm = OpenAI()\n",
    "\n",
    "agent = ReActAgent.from_tools(toolkit.tools, llm=llm, verbose=True)\n",
    "\n",
    "agent.chat(\"Please call a python execution tool to fetch the weather in Bogota Columbia.\")"
   ]
  }
 ],
 "metadata": {
  "application/vnd.databricks.v1+notebook": {
   "dashboards": [],
   "environmentMetadata": {
    "base_environment": "",
    "client": "1"
   },
   "language": "python",
   "notebookMetadata": {
    "pythonIndentUnit": 4
   },
   "notebookName": "llama_index_sample",
   "widgets": {}
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
